from ssl_utils.comp_models.mtcf import WideResNet2 as MTCFWideResNet
from ssl_utils.comp_models.mtcf import WideResNet as MixMatchWideResNet
from ssl_utils.comp_models.fixmatch import build_WideResNet as FixMatchWideResNetBuilder
from ssl_utils.comp_models.ds3l import WideResNet as DS3LWideResNet
from ssl_utils.comp_models.ds3l import WNet as DS3LWNet

from utils.model_normalization import Cifar10Wrapper
import os
import torch
import torch.nn as nn
import utils.datasets as dl
import ssl_utils as ssl
from utils.load_trained_model import load_model
import pathlib
import matplotlib as mpl
import pickle
import numpy as np
import scipy.signal as signal
import os

def load_comp_model(model_arch, model_name, device, device_ids, mtcf_ood=False):
    if model_arch == 'mtcf':
        model = MTCFWideResNet(num_classes=10)
        model_dir = '/mnt/SHARED/Max/MultiTaskCurriculum/runs_proposed'
        if model_name == 'tinyImages_1M':
            model_dir = os.path.join(model_dir, 'cifar10_tiny_images@4000')
            factor = 1.
        elif model_name == 'tinyImages_10M':
            model_dir = os.path.join(model_dir, 'cifar10_tiny_images_10M@4000')
            factor = 255.
        elif model_name == 'lsun_1M':
            model_dir = os.path.join(model_dir, 'cifar10_lsun_1M@4000')
            factor = 1.
        else:
            raise NotImplementedError()

        checkpoint_file = os.path.join(model_dir, 'model_best.pth.tar')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        model.load_state_dict(checkpoint['ema_state_dict'])

        class MTCFWrapper(nn.Module):
            def __init__(self, model, ood_test = False, only_ood=False, factor=1.):
                super().__init__()
                self.model = model
                self.ood_test = ood_test
                self.only_ood=only_ood
                self.factor = factor

            def forward(self, x, *args, **kwargs):
                if 'ood_test' not in kwargs:
                    kwargs['ood_test'] = self.ood_test
                if kwargs['ood_test']:
                    out, ood_out = self.model(self.factor * x ,*args, **kwargs)
                    if  self.only_ood:
                        return ood_out
                    else:
                        return out, ood_out

                return self.model(self.factor * x ,*args, **kwargs)


        model = MTCFWrapper(model, mtcf_ood, False, factor)

        if not model_name == 'tinyImages_10M':
            model = Cifar10Wrapper(model)
        else:
            pass
    elif model_arch == 'mixmatch':
        model = MixMatchWideResNet(num_classes=10)
        model_dir = '/mnt/SHARED/Max/MultiTaskCurriculum/runs_baseline'
        if model_name == 'tinyImages_1M':
            model_dir = os.path.join(model_dir, 'cifar10_tiny_images@4000')
        elif model_name == 'tinyImages_10M':
            model_dir = os.path.join(model_dir, 'cifar10_tiny_images_10M@4000')
        elif model_name == 'lsun_1M':
            model_dir = os.path.join(model_dir, 'cifar10_lsun_1M@4000')
        else:
            raise NotImplementedError()

        checkpoint_file = os.path.join(model_dir, 'model_best.pth.tar')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        model.load_state_dict(checkpoint['ema_state_dict'])
        model = Cifar10Wrapper(model)
    elif model_arch == 'fixmatch':
        model = FixMatchWideResNetBuilder(28, 2, 0.01, 0.1, 0.0).build(10)
        model_dir = '/mnt/SHARED/Max/FixMatch/saved_models'
        if model_name == 'tinyImages_1M':
            model_dir = os.path.join(model_dir, 'cifar10_4000_1M')
        elif model_name == 'tinyImages_10M':
            model_dir = os.path.join(model_dir, 'cifar10_4000_10M')
        elif model_name == 'lsun_1M':
            model_dir = os.path.join(model_dir, 'lsun_4000_1M')
        else:
            raise NotImplementedError()

        checkpoint_file = os.path.join(model_dir, 'model_best.pth')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')

        model.load_state_dict(checkpoint['eval_model'])
        model = Cifar10Wrapper(model)

    elif model_arch == 'ds3l':
        model = DS3LWideResNet(depth=28, widen_factor=2, n_classes=10)
        model_dir = '/mnt/SHARED/Max/DS3L/'
        if model_name == 'tinyImages_1M':
            checkpoint_file = os.path.join(model_dir, 'CIFAR10_TINYIMAGES_4000_1000000_save.pt')
        else:
            raise NotImplementedError()

        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        model.load_state_dict(checkpoint)


        def gcn(images, multiplier=55, eps=1e-10):
            # global contrast normalization
            images = images.astype(np.float)
            images -= images.mean(axis=(1, 2, 3), keepdims=True)
            per_image_norm = np.sqrt(np.square(images).sum((1, 2, 3), keepdims=True))
            per_image_norm[per_image_norm < eps] = 1
            images = multiplier * images / per_image_norm
            return images


        def get_zca_normalization_param(images, scale=0.1, eps=1e-10):
            n_data, height, width, channels = images.shape
            images = images.reshape(n_data, height * width * channels)
            image_cov = np.cov(images, rowvar=False)
            U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0]))
            zca_decomp = np.dot(U, np.dot(np.diag(1 / np.sqrt(S + eps)), U.T))
            mean = images.mean(axis=0)
            return mean, zca_decomp


        def zca_normalization(images, mean, decomp):
            n_data, height, width, channels = images.shape
            images = images.reshape(n_data, -1)
            images = np.dot((images - mean), decomp)
            return images.reshape(n_data, height, width, channels)


        zca_file = os.path.join(model_dir, 'zca_cifar10.npz')
        zca_npz = np.load(zca_file)
        mean = zca_npz['mean']
        zca_decomp = zca_npz['zca_decomp']


        class DS3LWrapper(nn.Module):
            def __init__(self, model, mean, zca_decomp, multiplier=55, eps=1e-10):
                super().__init__()

                self.model = model
                self.mean = mean
                self.zca_decomp = zca_decomp
                self.multiplier = multiplier
                self.eps = eps

            def forward(self, x):
                # gcn
                # zca
                x_np = 255. * x.detach().cpu().numpy().transpose(0, 2, 3, 1)
                x_np = gcn(x_np)
                x_np = zca_normalization(x_np, self.mean, self.zca_decomp).astype(np.float32).transpose(0, 3, 1, 2)
                x_zca = torch.from_numpy(x_np).to(x.device)
                return self.model(x_zca)


        model = DS3LWrapper(model, mean, zca_decomp)
    else:
        raise NotImplementedError()

    print('\n\n###############################################')
    print(f'{model_arch} - {model_name}')

    model.to(device)
    if device_ids is not None and len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model.eval()
    return model